package com.lmx.project.websocket;

import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.github.xiaoymin.knife4j.core.util.StrUtil;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.lmx.project.common.BaseResponse;
import com.lmx.project.common.ResultUtils;
import com.lmx.project.model.dto.chatgpt.ChatModel;
import com.lmx.project.model.entity.Chat;
import com.lmx.project.model.entity.Conversion;
import com.lmx.project.service.ChatService;
import com.lmx.project.service.ConversionService;
import com.lmx.project.until.ChatGptUntil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

@ServerEndpoint(value = "/websocket/{ConversionId}")
@Component
public class WebSocket {

    private static ChatGptUntil chatGptUntil;

    private static ChatService chatService;

    private static ConversionService conversionService;

    @Resource
    public void setConversionService(ConversionService conversionService) {
        WebSocket.conversionService = conversionService;
    }

    @Resource
    public void setChatService(ChatService chatService) {
        WebSocket.chatService = chatService;
    }

    @Resource
    public void setChatGptUntil(ChatGptUntil chatGptUntil) {
        WebSocket.chatGptUntil = chatGptUntil;
    }

    private final static Logger logger = LogManager.getLogger(WebSocket.class);

    /**
     * 静态变量，用来记录当前在线连接数。应该把它设计成线程安全的
     */

    private static int onlineCount = 0;

    /**
     * concurrent包的线程安全Map，用来存放每个客户端对应的MyWebSocket对象
     */
    private static ConcurrentHashMap<String, WebSocket> webSocketMap = new ConcurrentHashMap<>();

    /**
     * 与某个客户端的连接会话，需要通过它来给客户端发送数据
     */

    private Session session;
    private Long ConversionId;


    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("ConversionId") Long ConversionId) {
        this.session = session;
        this.ConversionId = ConversionId;
        //加入map
        webSocketMap.put(ConversionId.toString(), this);
        addOnlineCount();           //在线数加1
        logger.info("对话{}连接成功,当前在线人数为{}", ConversionId, getOnlineCount());
        try {
            sendMessage(String.valueOf(this.session.getQueryString()));
        } catch (IOException e) {
            logger.error("IO异常");
        }
    }


    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        //从map中删除
        webSocketMap.remove(ConversionId.toString());
        subOnlineCount();           //在线数减1
        logger.info("对话{}关闭连接！当前在线人数为{}", ConversionId, getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        logger.info("来自客户端对话：{} 消息:{}", ConversionId, message);


        Gson gson = new Gson();

//        ChatMessage chatMessage = gson.fromJson(message, ChatMessage.class);

        System.out.println(message);

//        Long conversionid = chatMessage.getConversionid();
//        if (conversionid == null) {
//            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是哪个对话");
//            String s = gson.toJson(baseResponse);
//            session.getBasicRemote().sendText(s);
//        }

        if (message == null) {
            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是该对话的用途");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }
//        将对话保存到数据库中
        Chat entity = new Chat();
        entity.setContext(message);
        entity.setConversionid(this.ConversionId);
        entity.setRole("user");
        boolean save = chatService.save(entity);

        if (!save) {
            BaseResponse baseResponse = ResultUtils.error(500, "数据库出现错误");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }


//        查询出身份
        Conversion byId = conversionService.getById(this.ConversionId);
        String instructions = byId.getInstructions();// 指令
//     给予chatgot身份
        ArrayList<ChatModel> chatModels = new ArrayList<>();
//        ChatModel scene = new ChatModel("user", instructions);
//        chatModels.add(scene);

        LambdaQueryWrapper<Chat> queryWrapper = new LambdaQueryWrapper<>();
        // 按照修改时间进行升序排序
        queryWrapper.eq(Chat::getConversionid, byId.getId()).orderByDesc(Chat::getUpdatedtime);
        List<Chat> list = chatService.list(queryWrapper);

//        查询之前的对话记录
        List<ChatModel> collect = list.stream().map(chat -> {
            ChatModel chatModel = new ChatModel();
            chatModel.setRole(chat.getRole());
            chatModel.setContent(chat.getContext());
//            BeanUtils.copyProperties(chat, chatModel);
            return chatModel;
        }).collect(Collectors.toList());
        chatModels.addAll(collect);


        chatGptUntil.getRespost(this.ConversionId, chatModels);
//        if (chatGptUntil==null){
//            System.out.println("chatuntil是空");
//        }
//
//        if (stringRedisTemplate==null){
//            System.out.println("缓存是空");
//        }


        //群发消息
        /*for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }*/
    }

    /**
     * 发生错误时调用
     *
     * @OnError
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("对话错误:" + this.ConversionId + ",原因:" + error.getMessage());
        error.printStackTrace();
    }

    /**
     * 向客户端发送消息
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
        //this.session.getAsyncRemote().sendText(message);
    }

    /**
     * 通过userId向客户端发送消息
     */
    public void sendMessageByUserId(Long ConversionId, String message) throws IOException {
        logger.info("服务端发送消息到{},消息：{}", ConversionId, message);
        if (StrUtil.isNotBlank(ConversionId.toString()) && webSocketMap.containsKey(ConversionId.toString())) {
            webSocketMap.get(ConversionId.toString()).sendMessage(message);
        } else {
            logger.error("{}不在线", ConversionId);
        }

    }

    /**
     * 群发自定义消息
     */
    public static void sendInfo(String message) {
        for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocket.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocket.onlineCount--;
    }

}
